# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os

import pytest

from haystack import Document, Pipeline
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.query import QueryExpander
from haystack.components.retrievers import InMemoryBM25Retriever, MultiQueryTextRetriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy


class TestMultiQueryTextRetriever:
    """Tests for ``MultiQueryTextRetriever`` wrapping an ``InMemoryBM25Retriever``."""

    @staticmethod
    def _assert_sorted_by_score(documents):
        """Assert that the documents carrying a score are ordered highest score first."""
        scores = [doc.score for doc in documents if doc.score is not None]
        assert scores == sorted(scores, reverse=True)

    @staticmethod
    def _assert_unique_contents(documents):
        """Assert that no two returned documents share the same content."""
        contents = [doc.content for doc in documents]
        assert len(contents) == len(set(contents))

    @pytest.fixture
    def sample_documents(self):
        """Seven short energy-related documents, each with a ``category`` meta field."""
        return [
            Document(
                content="Renewable energy is energy that is collected from renewable resources.",
                meta={"category": None},  # the only document without a concrete category
            ),
            Document(
                content="Solar energy is a type of green energy that is harnessed from the sun.",
                meta={"category": "solar"},
            ),
            Document(
                content="Wind energy is another type of green energy that is generated by wind turbines",
                meta={"category": "wind"},
            ),
            Document(
                content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
                meta={"category": "hydro"},
            ),
            Document(
                content="Geothermal energy is heat that comes from the sub-surface of the earth.",
                meta={"category": "geo"},
            ),
            Document(
                content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
                meta={"category": "fossil"},
            ),
            Document(
                content="Nuclear energy is produced through nuclear reactions, typically using uranium or plutonium "
                "as fuel.",
                meta={"category": "nuclear"},
            ),
        ]

    @pytest.fixture
    def document_store_with_docs(self, sample_documents):
        """An ``InMemoryDocumentStore`` populated with ``sample_documents`` (duplicates skipped)."""
        document_store = InMemoryDocumentStore()
        doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)
        doc_writer.run(documents=sample_documents)
        return document_store

    def test_init_with_default_parameters(self):
        """The wrapped retriever is stored as given and ``max_workers`` defaults to 3."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=InMemoryDocumentStore())
        retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        assert retriever.retriever is in_memory_retriever
        assert retriever.max_workers == 3

    def test_init_with_custom_parameters(self):
        """A custom ``max_workers`` value is stored on the component."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=InMemoryDocumentStore())
        retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
        assert retriever.retriever is in_memory_retriever
        assert retriever.max_workers == 2

    def test_run_with_multiple_queries(self, document_store_with_docs):
        """Several queries yield a non-empty, score-sorted, content-deduplicated document list."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        queries = ["renewable energy", "solar power", "wind turbines"]
        multi_retriever.warm_up()
        result = multi_retriever.run(queries=queries)

        assert "documents" in result
        documents = result["documents"]
        assert len(documents) > 0
        assert all(isinstance(doc, Document) for doc in documents)
        self._assert_sorted_by_score(documents)
        # Overlapping queries ("renewable energy" / "solar power") can hit the same document;
        # it must only be returned once.
        self._assert_unique_contents(documents)

    def test_to_dict(self):
        """Serialization includes ``max_workers`` and the nested retriever's fully-qualified type."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=InMemoryDocumentStore())
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=2)
        result = multi_retriever.to_dict()
        assert "type" in result
        assert "init_parameters" in result
        init_parameters = result["init_parameters"]
        assert init_parameters["max_workers"] == 2
        assert "retriever" in init_parameters
        assert (
            init_parameters["retriever"]["type"]
            == "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever"
        )

    def test_from_dict(self):
        """Deserialization rebuilds the component together with its nested BM25 retriever."""
        data = {
            "type": "haystack.components.retrievers.multi_query_text_retriever.MultiQueryTextRetriever",
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.bm25_retriever.InMemoryBM25Retriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "88144fa9-6e45-4e5d-8647-4c4002d8b6db",
                                # TODO: adding "return_embedding": True here currently makes this test fail;
                                # investigate why before including it.
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "filter_policy": "replace",
                    },
                },
                "max_workers": 3,
            },
        }

        result = MultiQueryTextRetriever.from_dict(data)

        assert isinstance(result, MultiQueryTextRetriever)
        assert result.retriever.__class__.__name__ == "InMemoryBM25Retriever"
        assert result.max_workers == 3

    @pytest.mark.integration
    def test_run_with_filters(self, document_store_with_docs):
        """Runtime filters passed through ``retriever_kwargs`` restrict the retrieved documents."""
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        filters = {"field": "category", "operator": "==", "value": "solar"}
        multi_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever)
        multi_retriever.warm_up()
        result = multi_retriever.run(queries=["energy"], retriever_kwargs={"filters": filters})

        assert "documents" in result
        documents = result["documents"]
        # Without this, all() over an empty list would pass even if filtering dropped everything.
        assert len(documents) > 0
        assert all(doc.meta.get("category") == "solar" for doc in documents)

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_pipeline_integration(self, document_store_with_docs):
        """End-to-end: an LLM ``QueryExpander`` feeds its queries into the multi-query retriever."""
        pipeline = Pipeline()
        expander = QueryExpander(
            chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), n_expansions=3, include_original_query=True
        )
        in_memory_retriever = InMemoryBM25Retriever(document_store=document_store_with_docs)
        multiquery_retriever = MultiQueryTextRetriever(retriever=in_memory_retriever, max_workers=3)
        pipeline.add_component("query_expander", expander)
        pipeline.add_component("multiquery_retriever", multiquery_retriever)
        pipeline.connect("query_expander.queries", "multiquery_retriever.queries")

        data = {
            "query_expander": {"query": "green energy sources"},
            "multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
        }
        results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})

        assert "multiquery_retriever" in results
        assert "documents" in results["multiquery_retriever"]
        documents = results["multiquery_retriever"]["documents"]
        assert len(documents) > 0

        assert "query_expander" in results
        assert "queries" in results["query_expander"]
        # n_expansions=3 plus the original query (include_original_query=True) -> 4 queries.
        assert len(results["query_expander"]["queries"]) == 4

        self._assert_sorted_by_score(documents)
        self._assert_unique_contents(documents)
